{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Word2Vec: Skipgram Model\n",
    "-------------------------------------\n",
    "\n",
    "In this example, we will download and preprocess the movie review data.\n",
    "\n",
    "From this data set we will compute/fit the skipgram model of the Word2Vec Algorithm\n",
    "\n",
    "Skipgram: based on predicting the surrounding words from a single input word.\n",
    "\n",
    "Ex sentence \"the cat in the hat\"\n",
    " - context word:  [\"hat\"]\n",
    " - target words: [\"the\", \"cat\", \"in\", \"the\"]\n",
    " - context-target pairs: (\"hat\", \"the\"), (\"hat\", \"cat\"), (\"hat\", \"in\"), (\"hat\", \"the\")\n",
    "\n",
    "We start by loading the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "import string\n",
    "import requests\n",
    "import collections\n",
    "import io\n",
    "import tarfile\n",
    "import gzip\n",
    "from nltk.corpus import stopwords\n",
    "from tensorflow.python.framework import ops\n",
    "ops.reset_default_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Start a computational graph session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Create the session that the later cells use to run graph operations (via sess.run).\n",
    "sess = tf.Session()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Declare model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "batch_size = 100         # How many sets of words to train on at once.\n",
    "embedding_size = 100    # The embedding size of each word to train.\n",
    "vocabulary_size = 5000 # How many words we will consider for training.\n",
    "generations = 100000    # How many iterations we will perform the training on.\n",
    "print_loss_every = 500  # Print the loss every so many iterations\n",
    "\n",
    "num_sampled = int(batch_size/2) # Number of negative examples to sample.\n",
    "window_size = 2         # How many words to consider left and right."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "We will remove stop words and create a test validation set of words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Validation settings: five fixed test words whose nearest neighbours are\n",
    "# printed every `print_valid_every` iterations. We are expecting synonyms to appear.\n",
    "# Later we will have to transform these words into indices.\n",
    "valid_words = ['cliche', 'love', 'hate', 'silly', 'sad']\n",
    "print_valid_every = 10000\n",
    "\n",
    "# English stop words (from NLTK) that will be removed from the reviews\n",
    "stops = stopwords.words('english')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Next, we load the movie review data.  We check whether the data has already been downloaded and, if not, download and save it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def load_movie_data():\n",
    "    save_folder_name = 'temp'\n",
    "    pos_file = os.path.join(save_folder_name, 'rt-polaritydata', 'rt-polarity.pos')\n",
    "    neg_file = os.path.join(save_folder_name, 'rt-polaritydata', 'rt-polarity.neg')\n",
    "\n",
    "    # Check if files are already downloaded\n",
    "    if not os.path.exists(os.path.join(save_folder_name, 'rt-polaritydata')):\n",
    "        movie_data_url = 'http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'\n",
    "\n",
    "        # Save tar.gz file\n",
    "        req = requests.get(movie_data_url, stream=True)\n",
    "        with open(os.path.join(save_folder_name,'temp_movie_review_temp.tar.gz'), 'wb') as f:\n",
    "            for chunk in req.iter_content(chunk_size=1024):\n",
    "                if chunk:\n",
    "                    f.write(chunk)\n",
    "                    f.flush()\n",
    "        # Extract tar.gz file into temp folder\n",
    "        tar = tarfile.open(os.path.join(save_folder_name,'temp_movie_review_temp.tar.gz'), \"r:gz\")\n",
    "        tar.extractall(path='temp')\n",
    "        tar.close()\n",
    "\n",
    "    pos_data = []\n",
    "    with open(pos_file, 'r', encoding='latin-1') as f:\n",
    "        for line in f:\n",
    "            pos_data.append(line.encode('ascii',errors='ignore').decode())\n",
    "    f.close()\n",
    "    pos_data = [x.rstrip() for x in pos_data]\n",
    "\n",
    "    neg_data = []\n",
    "    with open(neg_file, 'r', encoding='latin-1') as f:\n",
    "        for line in f:\n",
    "            neg_data.append(line.encode('ascii',errors='ignore').decode())\n",
    "    f.close()\n",
    "    neg_data = [x.rstrip() for x in neg_data]\n",
    "    \n",
    "    texts = pos_data + neg_data\n",
    "    target = [1]*len(pos_data) + [0]*len(neg_data)\n",
    "    \n",
    "    return(texts, target)\n",
    "\n",
    "\n",
    "texts, target = load_movie_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Now we create a function that normalizes/cleans the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Normalize text\n",
    "def normalize_text(texts, stops):\n",
    "    # Lower case\n",
    "    texts = [x.lower() for x in texts]\n",
    "\n",
    "    # Remove punctuation\n",
    "    texts = [''.join(c for c in x if c not in string.punctuation) for x in texts]\n",
    "\n",
    "    # Remove numbers\n",
    "    texts = [''.join(c for c in x if c not in '0123456789') for x in texts]\n",
    "\n",
    "    # Remove stopwords\n",
    "    texts = [' '.join([word for word in x.split() if word not in (stops)]) for x in texts]\n",
    "\n",
    "    # Trim extra whitespace\n",
    "    texts = [' '.join(x.split()) for x in texts]\n",
    "    \n",
    "    return(texts)\n",
    "    \n",
    "texts = normalize_text(texts, stops)\n",
    "\n",
    "# Keep only the reviews with at least 3 words, and keep the labels aligned with them\n",
    "kept_ix = [ix for ix, x in enumerate(texts) if len(x.split()) >= 3]\n",
    "target = [target[ix] for ix in kept_ix]\n",
    "texts = [texts[ix] for ix in kept_ix]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "With the normalized movie reviews, we now build a dictionary of words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Build dictionary of words\n",
    "def build_dictionary(sentences, vocabulary_size):\n",
    "    # Turn sentences (list of strings) into lists of words\n",
    "    split_sentences = [s.split() for s in sentences]\n",
    "    words = [x for sublist in split_sentences for x in sublist]\n",
    "    \n",
    "    # Initialize list of [word, word_count] for each word, starting with unknown\n",
    "    count = [['RARE', -1]]\n",
    "    \n",
    "    # Now add most frequent words, limited to the N-most frequent (N=vocabulary size)\n",
    "    count.extend(collections.Counter(words).most_common(vocabulary_size-1))\n",
    "    \n",
    "    # Now create the dictionary\n",
    "    word_dict = {}\n",
    "    # For each word, that we want in the dictionary, add it, then make it\n",
    "    # the value of the prior dictionary length\n",
    "    for word, word_count in count:\n",
    "        word_dict[word] = len(word_dict)\n",
    "    \n",
    "    return(word_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "With the above dictionary, we can turn text data into lists of integers from such dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def text_to_numbers(sentences, word_dict):\n",
    "    # Initialize the returned data\n",
    "    data = []\n",
    "    for sentence in sentences:\n",
    "        sentence_data = []\n",
    "        # For each word, either use selected index or rare word index\n",
    "        for word in sentence.split(' '):\n",
    "            if word in word_dict:\n",
    "                word_ix = word_dict[word]\n",
    "            else:\n",
    "                word_ix = 0\n",
    "            sentence_data.append(word_ix)\n",
    "        data.append(sentence_data)\n",
    "    return(data)\n",
    "\n",
    "# Word -> index dictionary, its inverse (index -> word), and the index-encoded reviews\n",
    "word_dictionary = build_dictionary(texts, vocabulary_size)\n",
    "word_dictionary_rev = {ix: word for word, ix in word_dictionary.items()}\n",
    "text_data = text_to_numbers(texts, word_dictionary)\n",
    "\n",
    "# Dictionary indices of the validation words\n",
    "valid_examples = [word_dictionary[word] for word in valid_words]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Let us now build a function that will generate random data points from our text and parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Generate data randomly (N words behind, target, N words ahead)\n",
    "def generate_batch_data(sentences, batch_size, window_size, method='skip_gram'):\n",
    "    # Fill up data batch\n",
    "    batch_data = []\n",
    "    label_data = []\n",
    "    while len(batch_data) < batch_size:\n",
    "        # select random sentence to start\n",
    "        rand_sentence = np.random.choice(sentences)\n",
    "        # Generate consecutive windows to look at\n",
    "        window_sequences = [rand_sentence[max((ix-window_size),0):(ix+window_size+1)] for ix, x in enumerate(rand_sentence)]\n",
    "        # Denote which element of each window is the center word of interest\n",
    "        label_indices = [ix if ix<window_size else window_size for ix,x in enumerate(window_sequences)]\n",
    "        \n",
    "        # Pull out center word of interest for each window and create a tuple for each window\n",
    "        if method=='skip_gram':\n",
    "            batch_and_labels = [(x[y], x[:y] + x[(y+1):]) for x,y in zip(window_sequences, label_indices)]\n",
    "            # Make it in to a big list of tuples (target word, surrounding word)\n",
    "            tuple_data = [(x, y_) for x,y in batch_and_labels for y_ in y]\n",
    "        elif method=='cbow':\n",
    "            batch_and_labels = [(x[:y] + x[(y+1):], x[y]) for x,y in zip(window_sequences, label_indices)]\n",
    "            # Make it in to a big list of tuples (target word, surrounding word)\n",
    "            tuple_data = [(x_, y) for x,y in batch_and_labels for x_ in x]\n",
    "        else:\n",
    "            raise ValueError('Method {} not implemented yet.'.format(method))\n",
    "            \n",
    "        # extract batch and labels\n",
    "        batch, labels = [list(x) for x in zip(*tuple_data)]\n",
    "        batch_data.extend(batch[:batch_size])\n",
    "        label_data.extend(labels[:batch_size])\n",
    "    # Trim batch and label at the end\n",
    "    batch_data = batch_data[:batch_size]\n",
    "    label_data = label_data[:batch_size]\n",
    "    \n",
    "    # Convert to numpy array\n",
    "    batch_data = np.array(batch_data)\n",
    "    label_data = np.transpose(np.array([label_data]))\n",
    "    \n",
    "    return(batch_data, label_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Next we define our model and placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Define Embeddings: a trainable matrix with one row of length embedding_size\n",
    "# per vocabulary index, initialized with uniform random values between -1.0 and 1.0.\n",
    "embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))\n",
    "\n",
    "# NCE loss parameters: a weight matrix with the same shape as the embeddings\n",
    "# (truncated-normal initialization, stddev = 1/sqrt(embedding_size)) and one\n",
    "# zero-initialized bias per vocabulary index.\n",
    "nce_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size],\n",
    "                                               stddev=1.0 / np.sqrt(embedding_size)))\n",
    "nce_biases = tf.Variable(tf.zeros([vocabulary_size]))\n",
    "\n",
    "# Create data/target placeholders: batch_size input word indices and a\n",
    "# (batch_size, 1) column of label word indices, both int32.\n",
    "# valid_dataset holds the indices of the validation words as a constant.\n",
    "x_inputs = tf.placeholder(tf.int32, shape=[batch_size])\n",
    "y_target = tf.placeholder(tf.int32, shape=[batch_size, 1])\n",
    "valid_dataset = tf.constant(valid_examples, dtype=tf.int32)\n",
    "\n",
    "# Lookup the word embedding: select the embedding rows for the input indices\n",
    "embed = tf.nn.embedding_lookup(embeddings, x_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Here is our loss function, optimizer, cosine similarity, and initialization of the model variables.\n",
    "\n",
    "For the loss function we will minimize the average of the NCE loss (noise-contrastive estimation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Get loss from prediction: the mean of the NCE loss, which is computed\n",
    "# from the looked-up embeddings, the label indices and num_sampled sampled classes.\n",
    "loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weights,\n",
    "                                     biases=nce_biases,\n",
    "                                     labels=y_target,\n",
    "                                     inputs=embed,\n",
    "                                     num_sampled=num_sampled,\n",
    "                                     num_classes=vocabulary_size))\n",
    "\n",
    "# Create optimizer: plain gradient descent on the loss with a learning rate of 1.0\n",
    "optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0).minimize(loss)\n",
    "\n",
    "# Cosine similarity between words: divide each embedding row by its L2 norm,\n",
    "# then multiply the validation rows with all (transposed) normalized rows.\n",
    "# The result has one row per validation word and one column per vocabulary index.\n",
    "# NOTE(review): the `keep_dims` argument name may differ in other TensorFlow\n",
    "# releases - confirm against the installed version.\n",
    "norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))\n",
    "normalized_embeddings = embeddings / norm\n",
    "valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)\n",
    "similarity = tf.matmul(valid_embeddings, normalized_embeddings, transpose_b=True)\n",
    "\n",
    "\n",
    "# Add variable initializer and run it in the session.\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Evaluate the similarity matrix right after variable initialization, i.e. before\n",
    "# any training step has run. No other cell visible here reads sim_init.\n",
    "sim_init = sess.run(similarity)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Now we can train our skip-gram model.\n",
    "\n",
    "> Note that we have the line: `nearest = (-sim[j, :]).argsort()[1:top_k+1]` below. The negative of the similarity matrix is used because `argsort()` sorts the values from least to greatest.  Since we want to take the greatest numbers, we sort in the opposite direction by taking the negative of the similarity matrix, then calling the `argsort()` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss at step 1000 : 4.620365619659424\n",
      "Loss at step 2000 : 3.711085319519043\n",
      "Loss at step 3000 : 4.3959550857543945\n",
      "Loss at step 4000 : 4.229242324829102\n",
      "Loss at step 5000 : 4.189347743988037\n",
      "Loss at step 6000 : 3.6959874629974365\n",
      "Loss at step 7000 : 4.364022254943848\n",
      "Loss at step 8000 : 4.148524284362793\n",
      "Loss at step 9000 : 3.7397570610046387\n",
      "Loss at step 10000 : 4.326292514801025\n",
      "Nearest to cliche: particularly, movie, artists, successful, understanding,\n",
      "Nearest to love: screen, leave, street, potential, RARE,\n",
      "Nearest to hate: joke, es, questions, creativity, con,\n",
      "Nearest to silly: eccentric, bag, glass, single, things,\n",
      "Nearest to sad: meditation, movie, RARE, whatever, fantastic,\n",
      "Loss at step 11000 : 3.664271593093872\n",
      "Loss at step 12000 : 4.203312397003174\n",
      "Loss at step 13000 : 3.925948143005371\n",
      "Loss at step 14000 : 3.9482924938201904\n",
      "Loss at step 15000 : 4.0689568519592285\n",
      "Loss at step 16000 : 4.345786094665527\n",
      "Loss at step 17000 : 3.8185791969299316\n",
      "Loss at step 18000 : 4.215057849884033\n",
      "Loss at step 19000 : 4.00248908996582\n",
      "Loss at step 20000 : 3.9021575450897217\n",
      "Nearest to cliche: artists, successful, enough, movie, understanding,\n",
      "Nearest to love: screen, potential, leave, come, faith,\n",
      "Nearest to hate: es, creativity, questions, adults, camp,\n",
      "Nearest to silly: bag, things, fly, glass, eccentric,\n",
      "Nearest to sad: try, meditation, movie, whatever, murder,\n",
      "Loss at step 21000 : 3.548168420791626\n",
      "Loss at step 22000 : 3.707749843597412\n",
      "Loss at step 23000 : 4.022873401641846\n",
      "Loss at step 24000 : 3.940373420715332\n",
      "Loss at step 25000 : 3.5750863552093506\n",
      "Loss at step 26000 : 3.937954902648926\n",
      "Loss at step 27000 : 3.6091718673706055\n",
      "Loss at step 28000 : 4.512173652648926\n",
      "Loss at step 29000 : 4.0766191482543945\n",
      "Loss at step 30000 : 4.037301540374756\n",
      "Nearest to cliche: artists, enough, understanding, successful, moments,\n",
      "Nearest to love: potential, screen, cult, RARE, war,\n",
      "Nearest to hate: es, creativity, exactly, questions, camp,\n",
      "Nearest to silly: bag, hybrid, things, cinematic, fly,\n",
      "Nearest to sad: film, movie, try, fantastic, faith,\n",
      "Loss at step 31000 : 3.926762104034424\n",
      "Loss at step 32000 : 3.800074815750122\n",
      "Loss at step 33000 : 3.9192094802856445\n",
      "Loss at step 34000 : 3.9992804527282715\n",
      "Loss at step 35000 : 3.567578077316284\n",
      "Loss at step 36000 : 3.527355909347534\n",
      "Loss at step 37000 : 3.8494346141815186\n",
      "Loss at step 38000 : 3.834221124649048\n",
      "Loss at step 39000 : 3.8048064708709717\n",
      "Loss at step 40000 : 3.8615174293518066\n",
      "Nearest to cliche: artists, enough, RARE, successful, understanding,\n",
      "Nearest to love: cult, potential, screen, RARE, young,\n",
      "Nearest to hate: es, creativity, exactly, camp, adults,\n",
      "Nearest to silly: hybrid, cinematic, bag, spielberg, fly,\n",
      "Nearest to sad: movie, starts, faith, fantastic, try,\n",
      "Loss at step 41000 : 3.39894437789917\n",
      "Loss at step 42000 : 3.7645909786224365\n",
      "Loss at step 43000 : 3.580610990524292\n",
      "Loss at step 44000 : 3.842486619949341\n",
      "Loss at step 45000 : 3.343255043029785\n",
      "Loss at step 46000 : 4.337518692016602\n",
      "Loss at step 47000 : 3.9678456783294678\n",
      "Loss at step 48000 : 3.0952281951904297\n",
      "Loss at step 49000 : 3.8245849609375\n",
      "Loss at step 50000 : 4.1148271560668945\n",
      "Nearest to cliche: RARE, enough, artists, understanding, moments,\n",
      "Nearest to love: cult, young, potential, screen, faith,\n",
      "Nearest to hate: es, creativity, exactly, camp, teenage,\n",
      "Nearest to silly: hybrid, cinematic, goes, inspiring, spielberg,\n",
      "Nearest to sad: starts, faith, beautiful, movie, subtle,\n",
      "Loss at step 51000 : 3.841761589050293\n",
      "Loss at step 52000 : 3.6969189643859863\n",
      "Loss at step 53000 : 3.572815179824829\n",
      "Loss at step 54000 : 4.168234348297119\n",
      "Loss at step 55000 : 4.432800769805908\n",
      "Loss at step 56000 : 3.983534574508667\n",
      "Loss at step 57000 : 3.5007388591766357\n",
      "Loss at step 58000 : 3.7342286109924316\n",
      "Loss at step 59000 : 4.190605640411377\n",
      "Loss at step 60000 : 3.747537136077881\n",
      "Nearest to cliche: enough, artists, understanding, successful, monty,\n",
      "Nearest to love: cult, potential, faith, beneath, screen,\n",
      "Nearest to hate: es, exactly, creativity, camp, adults,\n",
      "Nearest to silly: cinematic, hybrid, goes, romantic, inspiring,\n",
      "Nearest to sad: movie, starts, beautiful, faith, approach,\n",
      "Loss at step 61000 : 3.109525680541992\n",
      "Loss at step 62000 : 3.5799553394317627\n",
      "Loss at step 63000 : 3.7504074573516846\n",
      "Loss at step 64000 : 3.833890914916992\n",
      "Loss at step 65000 : 3.561004638671875\n",
      "Loss at step 66000 : 2.9587860107421875\n",
      "Loss at step 67000 : 4.141056060791016\n",
      "Loss at step 68000 : 4.215058326721191\n",
      "Loss at step 69000 : 4.030611991882324\n",
      "Loss at step 70000 : 2.7594995498657227\n",
      "Nearest to cliche: enough, artists, understanding, moments, movie,\n",
      "Nearest to love: cult, faith, street, potential, price,\n",
      "Nearest to hate: es, creativity, everyone, camp, adults,\n",
      "Nearest to silly: goes, cinematic, hybrid, spielberg, inspiring,\n",
      "Nearest to sad: movie, starts, beautiful, approach, meditation,\n",
      "Loss at step 71000 : 3.3502533435821533\n",
      "Loss at step 72000 : 4.011401176452637\n",
      "Loss at step 73000 : 4.242837905883789\n",
      "Loss at step 74000 : 3.6275765895843506\n",
      "Loss at step 75000 : 4.354449272155762\n",
      "Loss at step 76000 : 3.369835138320923\n",
      "Loss at step 77000 : 4.044325351715088\n",
      "Loss at step 78000 : 4.546400547027588\n",
      "Loss at step 79000 : 3.2915635108947754\n",
      "Loss at step 80000 : 3.6428585052490234\n",
      "Nearest to cliche: enough, artists, understanding, monty, successful,\n",
      "Nearest to love: cult, reno, potential, faith, little,\n",
      "Nearest to hate: es, creativity, everyone, con, camp,\n",
      "Nearest to silly: goes, cinematic, spielberg, hybrid, inspiring,\n",
      "Nearest to sad: movie, starts, beautiful, approach, cheesy,\n",
      "Loss at step 81000 : 3.881168842315674\n",
      "Loss at step 82000 : 4.05750846862793\n",
      "Loss at step 83000 : 4.228786468505859\n",
      "Loss at step 84000 : 4.043832778930664\n",
      "Loss at step 85000 : 3.952359676361084\n",
      "Loss at step 86000 : 3.678485155105591\n",
      "Loss at step 87000 : 3.2399590015411377\n",
      "Loss at step 88000 : 3.8945517539978027\n",
      "Loss at step 89000 : 3.8429043292999268\n",
      "Loss at step 90000 : 4.262929916381836\n",
      "Nearest to cliche: enough, artists, understanding, monty, edge,\n",
      "Nearest to love: cult, enjoy, reno, beneath, faith,\n",
      "Nearest to hate: creativity, es, everyone, con, generally,\n",
      "Nearest to silly: goes, spielberg, cinematic, romantic, odd,\n",
      "Nearest to sad: starts, beautiful, daring, movie, faith,\n",
      "Loss at step 91000 : 4.095947742462158\n",
      "Loss at step 92000 : 3.5220530033111572\n",
      "Loss at step 93000 : 3.7437546253204346\n",
      "Loss at step 94000 : 4.10945987701416\n",
      "Loss at step 95000 : 3.658798933029175\n",
      "Loss at step 96000 : 3.8866307735443115\n",
      "Loss at step 97000 : 3.4304962158203125\n",
      "Loss at step 98000 : 3.6770119667053223\n",
      "Loss at step 99000 : 3.936981439590454\n",
      "Loss at step 100000 : 3.5797817707061768\n",
      "Nearest to cliche: artists, monty, understanding, successful, edge,\n",
      "Nearest to love: cult, essentially, enjoy, hold, cautionary,\n",
      "Nearest to hate: everyone, es, creativity, con, generally,\n",
      "Nearest to silly: odd, goes, action, romantic, austin,\n",
      "Nearest to sad: starts, match, cheesy, faith, characters,\n"
     ]
    }
   ],
   "source": [
    "# Run the skip gram model.\n",
    "top_k = 5  # number of nearest neighbors printed for each validation word\n",
    "loss_vec = []    # loss values recorded every print_loss_every iterations\n",
    "loss_x_vec = []  # iteration numbers matching the entries of loss_vec\n",
    "for i in range(generations):\n",
    "    batch_inputs, batch_labels = generate_batch_data(text_data, batch_size, window_size)\n",
    "    feed_dict = {x_inputs : batch_inputs, y_target : batch_labels}\n",
    "\n",
    "    # Run the train step\n",
    "    sess.run(optimizer, feed_dict=feed_dict)\n",
    "\n",
    "    # Record and print the loss, evaluated on the batch that was just trained on\n",
    "    if (i+1) % print_loss_every == 0:\n",
    "        loss_val = sess.run(loss, feed_dict=feed_dict)\n",
    "        loss_vec.append(loss_val)\n",
    "        loss_x_vec.append(i+1)\n",
    "        print(\"Loss at step {} : {}\".format(i+1, loss_val))\n",
    "\n",
    "    # Validation: print the top_k most similar words for each fixed validation word\n",
    "    if (i+1) % print_valid_every == 0:\n",
    "        sim = sess.run(similarity)\n",
    "        for j, valid_ix in enumerate(valid_examples):\n",
    "            valid_word = word_dictionary_rev[valid_ix]\n",
    "            # argsort() sorts ascending, so negate to rank by descending similarity.\n",
    "            # Rank 0 is skipped: it is expected to be the validation word itself.\n",
    "            nearest = (-sim[j, :]).argsort()[1:top_k+1]\n",
    "            close_words = [word_dictionary_rev[ix] for ix in nearest]\n",
    "            log_str = \"Nearest to {}:\".format(valid_word)\n",
    "            log_str += ''.join(' {},'.format(word) for word in close_words)\n",
    "            print(log_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
